{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a3efea70-191e-40a1-abe6-c8aa0c4535c4",
   "metadata": {},
   "source": [
    "# Auto Tensorization in BitBLAS\n",
    "\n",
    "Auto detect whether a given operator (gemm, conv2d, stencil, etc.) can be tensorized with given instructions' computation flow (MMA, DP4A, etc.)\n",
    "\n",
    "![image.png](./img/AutoTensorization.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7f17ee04-4406-4948-98e9-f61a42ed563d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import bitblas\n",
    "from bitblas import tvm\n",
    "from tvm import te, tir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa6c2458-7b4f-4951-addc-480df5fd9ef2",
   "metadata": {},
   "source": [
    "Get a convlution expression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3d3016d6-57c0-40d9-bee0-72582a3c9365",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype=\"float16\", out_dtype=\"float16\"):\n",
    "    A = te.placeholder((n, h, w, c), name=\"input\", dtype=in_dtype)\n",
    "    B = te.placeholder((kh, kw, c, f), name=\"weight\", dtype=in_dtype)\n",
    "\n",
    "    pad_shape = (n, h + 2 * p, w + 2 * p, c)\n",
    "    pad_value = tir.const(0.0, A.dtype)\n",
    "    pad = te.compute(\n",
    "        pad_shape,\n",
    "        lambda n, h, w, c: te.if_then_else(\n",
    "            tir.all(\n",
    "                h >= p,\n",
    "                w >= p,\n",
    "                h < pad_shape[1] - p,\n",
    "                w < pad_shape[2] - p,\n",
    "            ),\n",
    "            A[n, h - p, w - p, c],\n",
    "            pad_value,\n",
    "        ),\n",
    "        name=\"pad\",\n",
    "    )\n",
    "    kernel_h, kernel_w = kh, kw\n",
    "    stride_h, stride_w = s, s\n",
    "    dilation_h, dilation_w = d, d\n",
    "    out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1\n",
    "    out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1\n",
    "    out_shape = (n, out_h, out_w, f)\n",
    "    kh = te.reduce_axis((0, kernel_h), name=\"kh\")\n",
    "    kw = te.reduce_axis((0, kernel_w), name=\"kw\")\n",
    "    c = te.reduce_axis((0, c), name=\"c\")\n",
    "    C = te.compute(\n",
    "        out_shape,\n",
    "        lambda n, h, w, f: te.sum(\n",
    "            pad[\n",
    "                n,\n",
    "                h * stride_h + kh * dilation_h,\n",
    "                w * stride_w + kw * dilation_w,\n",
    "                c,\n",
    "            ] * B[kh, kw, c, f],\n",
    "            axis=[kh, kw, c],\n",
    "        ),\n",
    "        name=\"C\",\n",
    "    )\n",
    "    return te.create_prim_func([A, B, C])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8cd2f6ce-69b9-47b7-b466-790ea4712c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "func = conv2d_nhwc_hwio(128, 64, 224, 224, 64, 1, 1, 2, 1, 3, \"float16\", \"float16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a2b1e4e2-8a34-437f-a9f4-72d59a7eef33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# from tvm.script import tir as T\n",
      "\n",
      "@T.prim_func\n",
      "def main(input: T.Buffer((128, 224, 224, 64), \"float16\"), weight: T.Buffer((1, 1, 64, 64), \"float16\"), C: T.Buffer((128, 115, 115, 64), \"float16\")):\n",
      "    T.func_attr({\"tir.noalias\": T.bool(True)})\n",
      "    # with T.block(\"root\"):\n",
      "    pad = T.alloc_buffer((128, 230, 230, 64), \"float16\")\n",
      "    for n, h, w, c in T.grid(128, 230, 230, 64):\n",
      "        with T.block(\"pad\"):\n",
      "            v_n, v_h, v_w, v_c = T.axis.remap(\"SSSS\", [n, h, w, c])\n",
      "            T.reads(input[v_n, v_h - 3, v_w - 3, v_c])\n",
      "            T.writes(pad[v_n, v_h, v_w, v_c])\n",
      "            pad[v_n, v_h, v_w, v_c] = T.if_then_else(3 <= v_h and 3 <= v_w and v_h < 227 and v_w < 227, input[v_n, v_h - 3, v_w - 3, v_c], T.float16(0))\n",
      "    for n, h, w, f, kh, kw, c in T.grid(128, 115, 115, 64, 1, 1, 64):\n",
      "        with T.block(\"C\"):\n",
      "            v_n, v_h, v_w, v_f, v_kh, v_kw, v_c = T.axis.remap(\"SSSSRRR\", [n, h, w, f, kh, kw, c])\n",
      "            T.reads(pad[v_n, v_h * 2 + v_kh, v_w * 2 + v_kw, v_c], weight[v_kh, v_kw, v_c, v_f])\n",
      "            T.writes(C[v_n, v_h, v_w, v_f])\n",
      "            with T.init():\n",
      "                C[v_n, v_h, v_w, v_f] = T.float16(0)\n",
      "            C[v_n, v_h, v_w, v_f] = C[v_n, v_h, v_w, v_f] + pad[v_n, v_h * 2 + v_kh, v_w * 2 + v_kw, v_c] * weight[v_kh, v_kw, v_c, v_f]\n"
     ]
    }
   ],
   "source": [
    "print(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3615d7cd-7e81-4c67-91ce-1fe922fe11f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "target='nvidia/geforce-rtx-4090'\n"
     ]
    }
   ],
   "source": [
    "from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags\n",
    "from bitblas.base.arch import CUDA\n",
    "\n",
    "target = bitblas.auto_detect_nvidia_target()\n",
    "print(f\"{target=}\")\n",
    "arch = CUDA(target)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "981e0121-95fa-46d5-8c94-6992709f64f5",
   "metadata": {},
   "source": [
    "## Get Tensorized Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fcf8597d-9131-4441-b612-782ed9d66a13",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# from tvm.script import tir as T\n",
      "\n",
      "@T.prim_func\n",
      "def main(input: T.Buffer((128, 224, 224, 64), \"float16\"), weight: T.Buffer((1, 1, 64, 64), \"float16\"), C: T.Buffer((128, 115, 115, 64), \"float16\")):\n",
      "    T.func_attr({\"dlight.tensorcore_prenormlized\": T.bool(True), \"tir.noalias\": T.bool(True)})\n",
      "    # with T.block(\"root\"):\n",
      "    pad = T.alloc_buffer((128, 230, 230, 64), \"float16\")\n",
      "    pad_reindex = T.alloc_buffer((1, 1692800, 64), \"float16\")\n",
      "    weight_reindex = T.alloc_buffer((1, 64, 64), \"float16\")\n",
      "    C_reindex = T.alloc_buffer((1, 1692800, 64), \"float16\")\n",
      "    for n, h, w, c in T.grid(128, 230, 230, 64):\n",
      "        with T.block(\"pad\"):\n",
      "            v_n, v_h, v_w, v_c = T.axis.remap(\"SSSS\", [n, h, w, c])\n",
      "            T.reads(input[v_n, v_h - 3, v_w - 3, v_c])\n",
      "            T.writes(pad[v_n, v_h, v_w, v_c])\n",
      "            pad[v_n, v_h, v_w, v_c] = T.if_then_else(3 <= v_h and 3 <= v_w and v_h < 227 and v_w < 227, input[v_n, v_h - 3, v_w - 3, v_c], T.float16(0))\n",
      "    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(128, 115, 115, 1, 1, 64):\n",
      "        with T.block(\"pad_reindex_reindex\"):\n",
      "            v0, v1, v2, v3, v4, v5 = T.axis.remap(\"SSSSSS\", [ax0, ax1, ax2, ax3, ax4, ax5])\n",
      "            T.reads(pad[v0, v1 * 2 + v3, v2 * 2 + v4, v5])\n",
      "            T.writes(pad_reindex[0, v0 * 13225 + v1 * 115 + v2, v5])\n",
      "            pad_reindex[0, v0 * 13225 + v1 * 115 + v2, v5] = pad[v0, v1 * 2 + v3, v2 * 2 + v4, v5]\n",
      "    for ax0, ax1, ax2, ax3 in T.grid(64, 1, 1, 64):\n",
      "        with T.block(\"weight_reindex_reindex\"):\n",
      "            v0, v1, v2, v3 = T.axis.remap(\"SSSS\", [ax0, ax1, ax2, ax3])\n",
      "            T.reads(weight[v1, v2, v3, v0])\n",
      "            T.writes(weight_reindex[0, v3, v0])\n",
      "            weight_reindex[0, v3, v0] = weight[v1, v2, v3, v0]\n",
      "    for ax0, ax1, ax2, ax3 in T.grid(1, 1692800, 64, 64):\n",
      "        with T.block(\"C\"):\n",
      "            v0, v1, v2, v3 = T.axis.remap(\"SSSR\", [ax0, ax1, ax2, ax3])\n",
      "            T.reads(pad_reindex[0, v1, v3], weight_reindex[0, v3, v2])\n",
      "            T.writes(C_reindex[0, v1, v2])\n",
      "            with T.init():\n",
      "                C_reindex[0, v1, v2] = T.float16(0)\n",
      "            C_reindex[0, v1, v2] = C_reindex[0, v1, v2] + pad_reindex[0, v1, v3] * weight_reindex[0, v3, v2]\n",
      "    for ax0, ax1, ax2, ax3 in T.grid(128, 115, 115, 64):\n",
      "        with T.block(\"C_reindex\"):\n",
      "            v0, v1, v2, v3 = T.axis.remap(\"SSSS\", [ax0, ax1, ax2, ax3])\n",
      "            T.reads(C_reindex[0, v0 * 13225 + v1 * 115 + v2, v3])\n",
      "            T.writes(C[v0, v1, v2, v3])\n",
      "            C[v0, v1, v2, v3] = C_reindex[0, v0 * 13225 + v1 * 115 + v2, v3]\n"
     ]
    }
   ],
   "source": [
    "tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)\n",
    "print(tensorized_func)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
